# Import the required modules
import matplotlib.pyplot as plt
from scipy.stats import chi2_contingency
import numpy as np
import math
import pandas as pd
from scipy import stats
from collections import Counter


# Load the input datasets.

# Follow network: only the distinct values of its 'Source' column are used below.
political_follow = pd.read_csv('./Data/follow_merge_2.csv')
political_source = political_follow['Source'].unique().tolist()

# Survey/influencer network; the first CSV column is read as the index.
survey_influencer_network = pd.read_csv('./Data/Survey_Influencer_Network.csv', index_col=0)

# Full survey table. The first spreadsheet column is read as the index and
# then turned back into a regular column by reset_index().
survey_data = pd.read_excel('./Data/Survey_Data.xlsx', index_col=0).reset_index()

# Keep the 'Source' values of the survey network that also occur in the follow
# network, in order of first appearance. Membership is tested against a set so
# the filter is linear instead of scanning the whole list for every id.
_political_source_lookup = set(political_source)
survey_list = [
    source
    for source in survey_influencer_network['Source'].unique()
    if source in _political_source_lookup
]

# Survey rows whose panelist_id is one of the retained sources
# (the sub-sample that is compared against the full survey).
merge_mapper_survey = survey_data[survey_data['panelist_id'].isin(survey_list)]
# ---------------------------------------------------------------------------
# Sample-validation figure.
#
# For six survey variables, compare the distribution in the sub-sample
# (merge_mapper_survey) with the distribution among all survey rows
# (survey_data): draw both as side-by-side bars of category proportions and
# annotate each panel with a two-sample test statistic and p-value.
#
# NOTE(review): merge_mapper_survey is obtained by filtering survey_data, so
# the two groups passed to the tests overlap -- confirm this is the intended
# comparison.
# ---------------------------------------------------------------------------

# Columns to plot, in panel order (row-major over the 3 x 2 grid).
columns = ['Age (C)', 'Gender', 'Ethnic', 'Religion', 'Income', 'Education']

# Columns compared with a Mann-Whitney U test; every other column is compared
# with Pearson's chi-squared test on the 2 x k frequency table.
ORDINAL_COLUMNS = ('Age (C)', 'Income', 'Education')

# Width of one bar, in units of the spacing between category slots.
BAR_WIDTH = 0.3

# Per-panel presentation settings.
#   labels     -- tick labels, applied positionally to the sorted distinct
#                 values of the column (assumes every coded category is
#                 observed and that sorted codes follow this order -- TODO
#                 confirm against the survey codebook)
#   label_size -- font size of the tick labels
#   y_pad      -- head-room added above the tallest bar
#   p_xy       -- axes-fraction position of the p-value text
#   stat_xy    -- axes-fraction position of the statistic text
#   stat_fmt   -- format spec used for the statistic
#   xlabel     -- axis title (defaults to the column name)
PANEL_SPECS = {
    'Age (C)': {
        'xlabel': 'Age',
        'labels': ['16-24', '25-34', '35-44', '45-54', '55+'],
        'label_size': 25,
        'y_pad': 0.2,
        'p_xy': (0.80, 0.68),
        'stat_xy': (0.77, 0.74),
        'stat_fmt': '.3f',
    },
    'Gender': {
        'labels': ['Homem (Male)', 'Mulher (Female)'],
        'label_size': 30,
        'y_pad': 0.3,
        'p_xy': (0.73, 0.68),
        'stat_xy': (0.76, 0.74),
        'stat_fmt': '.3f',
    },
    'Ethnic': {
        'labels': ['Branca (White)', 'Preta (Black)', 'Parda (Mixed)', 'Amarela (Asia)', 'Indígena (Indigenous)', 'Outro (Other)'],
        'label_size': 25,
        'y_pad': 0.2,
        'p_xy': (0.73, 0.68),
        'stat_xy': (0.765, 0.74),
        'stat_fmt': '.3f',
    },
    'Religion': {
        'labels': ['Católico (Catholic)', 'P. Ev. (Prot. Evang.)', 'P. N-Ev. (Prot. Non-Evang.)', 'N-Crist. (Non-Christian)', 'T.J. (Jeova’s Witness)', 'Afro-Br. (Afro-Brazilian)', 'Kardecista (Kardecist)', 'Judeu (Jewish)', 'Outras (Others)', 'Agnóstico (Agnostic)', 'Ateu (Atheist)'],
        'label_size': 25,
        'y_pad': 0.2,
        'p_xy': (0.73, 0.68),
        'stat_xy': (0.765, 0.74),
        'stat_fmt': '.3f',
    },
    'Income': {
        'labels': ['≤$1.2K', '$1.2-2.4K', '$2.4-3.6K', '$3.6-6K', '$6K-12K', '$12-24K', '$24-36K', '>$36K'],
        'label_size': 25,
        'y_pad': 0.2,
        'p_xy': (0.80, 0.68),
        'stat_xy': (0.826, 0.74),
        'stat_fmt': '.0f',
    },
    'Education': {
        'labels': ['Pre-K', 'Elem-5', 'Elem-9', 'HS-1', 'HS-2', 'HS-3', 'Inc.HE', 'Comp.HE', 'PG/MSc', 'Ph.D.'],
        'label_size': 25,
        'y_pad': 0.3,
        'p_xy': (0.80, 0.68),
        'stat_xy': (0.77, 0.74),
        'stat_fmt': '.3f',
    },
}


def _two_sample_test(col, sample, population):
    """Return ``(statistic, p_value)`` comparing *sample* with *population*.

    Columns listed in ORDINAL_COLUMNS use a two-sided Mann-Whitney U test.
    All other columns use Pearson's chi-squared test on the 2 x k table of
    category counts (row 0: sample, row 1: population).
    """
    if col in ORDINAL_COLUMNS:
        stat, p_value = stats.mannwhitneyu(sample, population, alternative='two-sided')
        return stat, p_value

    sample_counts = Counter(sample)
    population_counts = Counter(population)
    # Sorted only to make the table layout deterministic; the chi-squared
    # statistic does not depend on column order.
    categories = sorted(set(sample_counts) | set(population_counts))
    contingency_table = np.array([
        [sample_counts.get(cat, 0) for cat in categories],
        [population_counts.get(cat, 0) for cat in categories],
    ])
    stat, p_value, _dof, _expected = stats.chi2_contingency(contingency_table)
    return stat, p_value


def _category_shares(values, categories):
    """Return the fraction of *values* equal to each entry of *categories*."""
    observed = np.asarray(values)
    counts = np.array([np.count_nonzero(observed == cat) for cat in categories])
    return counts / observed.size


fig, axes = plt.subplots(3, 2, figsize=(25, 35))
fig.subplots_adjust(wspace=0.4, hspace=1.0)

for col, ax in zip(columns, axes.flatten()):
    spec = PANEL_SPECS[col]

    # Non-missing values of the column in the sub-sample and in the full survey.
    sample = merge_mapper_survey[col].dropna().tolist()
    population = survey_data[col].dropna().tolist()

    stat, p_value = _two_sample_test(col, sample, population)

    # Sorted distinct values seen in either group: one bar slot per value.
    categories = np.unique(np.concatenate((sample, population)))

    # Labels are positional, so a count mismatch would silently mislabel bars
    # (or fail later inside matplotlib with a less helpful message).
    labels = spec['labels']
    if len(labels) != len(categories):
        raise ValueError(
            f"Column {col!r}: {len(labels)} tick labels are configured but "
            f"{len(categories)} distinct values were observed."
        )

    # Bar heights are plain category proportions. They are computed directly
    # rather than via np.histogram(density=True), which only equals the
    # proportion when neighbouring codes are exactly 1 apart.
    sample_share = _category_shares(sample, categories)
    population_share = _category_shares(population, categories)

    # Evenly spaced slots; the absolute x values are irrelevant because the
    # tick labels are replaced by the category names below.
    sample_x = np.arange(len(categories), dtype=float)
    population_x = sample_x + BAR_WIDTH

    bars1 = ax.bar(sample_x, sample_share, width=BAR_WIDTH,
                   color='skyblue', edgecolor='black', label='Network Data')

    # The survey bars live on a twin axis that is given the same y-limits
    # (and no ticks of its own), so both series share one visible scale.
    ax2 = ax.twinx()
    bars2 = ax2.bar(population_x, population_share, width=BAR_WIDTH,
                    color='orange', edgecolor='black', label='Survey Data')

    # One tick per category, centred between the two bars of the pair.
    ax.set_xticks(sample_x + BAR_WIDTH / 2)
    ax.tick_params(axis='both', labelsize=30)
    ax2.tick_params(axis='y', labelsize=30)
    ax.set_xticklabels(labels, fontsize=spec['label_size'], rotation=90)
    ax.set_xlabel(spec.get('xlabel', col), fontsize=40)

    ax.legend([bars1, bars2], ['204 Sub-sample', 'Survey participants'], loc='upper right', fontsize=22)

    # Head-room above the tallest bar of EITHER series, so that neither set
    # of bars can be clipped by the shared limit.
    y_top = max(sample_share.max(), population_share.max()) + spec['y_pad']
    ax.set_ylim(0, y_top)
    ax2.set_ylim(0, y_top)
    ax2.set_yticks([])

    # Annotate the panel with the test result.
    test_prefix = 'MWU' if col in ORDINAL_COLUMNS else 'Chi-squared'
    p_x, p_y = spec['p_xy']
    stat_x, stat_y = spec['stat_xy']
    ax.text(p_x, p_y, f'{test_prefix} p-value: {p_value:.3f}', horizontalalignment='center',
            verticalalignment='center', transform=ax.transAxes, fontsize=22)
    ax.text(stat_x, stat_y, f'{test_prefix} stat: {format(stat, spec["stat_fmt"])}', horizontalalignment='center',
            verticalalignment='center', transform=ax.transAxes, fontsize=22)

plt.savefig('./Plots/SI_Fig.1_Sample_validation.png', dpi=500, bbox_inches='tight')

plt.show()
